Performance Analysis¶
We'll start by analyzing our best-performing model (V1/baseline).
After ~150-200 tuning runs we've selected this configuration:
| Parameter | Value |
|---|---|
| anneal_strategy | "cos" |
| base_lr | 0.0068893981577029285 |
| batch_size | 256 |
| div_factor | 24 |
| dropout | 0.1 |
| final_div_factor | 2,873 |
| freeze_epochs | 0 |
| gender_loss_weight | 0.9 |
| l1_lambda | 0.0001 |
| lr_scheduler | "one_cycle" |
| max_lr | 0.012321315111072404 |
| model_type | "mobilenet_v3_small" |
| num_epochs | 18 |
| override_cycle_epoch_count | 15 |
| pct_start | 0.36685557351085574 |
| prefix | "fixed_samples_final_full_split_15_cycle+3" |
| train_path | "dataset/train_8_folds_first" |
| use_dynamic_augmentation | false |
| val_path | "dataset/test_2_folds_last" |
| weight_decay | 0.00019323262043373016 |
Main parameters:
OneCycle with cosine annealing achieved considerably faster convergence and better generalization: only 15-20 epochs with a batch size of 256 were needed to reach optimal performance, compared to 25-35+ epochs for step-decay or reduce-on-plateau schedulers. AdamW was used as the optimizer.
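A minimal sketch of this optimizer/scheduler pairing, using (rounded) values from the configuration table above; the model and `steps_per_epoch` are placeholders, not the notebook's actual training code:

```python
import torch

model = torch.nn.Linear(576, 3)  # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0069, weight_decay=1.93e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.0123,          # max_lr from the table
    epochs=15,              # override_cycle_epoch_count
    steps_per_epoch=100,    # placeholder: len(train_loader)
    pct_start=0.367,        # fraction of the cycle spent warming up
    anneal_strategy="cos",  # cosine annealing
    div_factor=24,          # initial_lr = max_lr / 24
    final_div_factor=2873,  # min_lr = initial_lr / 2873
)
```

Note that `OneCycleLR` overrides the optimizer's `lr` on construction, so the `lr=0.0069` passed to AdamW is effectively replaced by `max_lr / div_factor`.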
In addition, we've used L1 regularization (AdamW already has L2 built in via weight decay) and dropout (applied only to our final classifier/regression layers) to reduce overfitting. We've observed a relatively small impact on validation/training performance with the UTK dataset; however, it should theoretically have a bigger impact on real-world/production data.
- best performing scheduler: one_cycle with cosine annealing
- optimal set of transformations/augmentations

We've arrived at this configuration by trying to maximize only the high-level model performance:
- total weighted loss (combined from the normalized gender and age prediction losses)
- gender prediction accuracy
- MAE for age predictions
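A hedged sketch of how such a combined objective can be assembled: a weighted sum of the gender cross-entropy and the age L1 loss, plus an explicit L1 penalty on the head weights (AdamW's weight decay already supplies L2). The constant names mirror the configuration table, but the exact way `gender_loss_weight` enters the sum is an assumption:

```python
import torch
import torch.nn as nn

gender_loss_fn = nn.CrossEntropyLoss()
age_loss_fn = nn.L1Loss()
gender_loss_weight = 0.9   # from the config table
l1_lambda = 1e-4           # l1_lambda from the config table

def combined_loss(gender_logits, gender_target, age_pred, age_target, head_params):
    # Weighted sum of the two task losses plus an L1 penalty on head weights.
    gender_loss = gender_loss_fn(gender_logits, gender_target)
    age_loss = age_loss_fn(age_pred.squeeze(-1), age_target)
    l1_penalty = sum(p.abs().sum() for p in head_params)
    return gender_loss_weight * gender_loss + age_loss + l1_lambda * l1_penalty
```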
Main Observations¶
Using `one_cycle` as our LR scheduler has allowed us to achieve convergence in only ~15 epochs while providing significantly better performance than `reduce_on_plateau` or `step_lr` were able to achieve even after 30-40 epochs.

The model was fine-tuned using pretrained weights (`IMAGENET1K_V1`) with `freeze_epochs = 0`, i.e. no backbone layers were frozen. We've found that training MobileNet from scratch (using randomly initialized weights) can provide comparable or only slightly inferior performance on the UTK dataset. We've still chosen to use the pretrained weights because:
- the model still performs a bit better (0.015 higher accuracy, ~0.2 lower MAE)
- because the pretrained model was trained on a much wider variety of images in different conditions, it should perform better (or at least no worse) on faces in real-world conditions.
using model_type = mobilenet_v3_small || pretrained = True ||
AgeGenderClassifier(
(base_model): Sequential(
(0): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(2): Conv2dNormActivation(
(0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(2): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(16, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=72, bias=False)
(1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(3): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 88, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(88, 88, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=88, bias=False)
(1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(88, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(4): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(96, 96, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=96, bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(96, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(5): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(6): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(7): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
(1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(120, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(8): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(144, 144, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=144, bias=False)
(1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(40, 144, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(144, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(9): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(288, 288, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=288, bias=False)
(1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(288, 72, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(72, 288, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(10): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(11): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(12): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
)
(1): AdaptiveAvgPool2d(output_size=1)
)
(global_pool): AdaptiveAvgPool2d(output_size=1)
(gender_classifier): Sequential(
(0): Dropout(p=0.1, inplace=False)
(1): Linear(in_features=576, out_features=2, bias=True)
)
(age_regressor): Sequential(
(0): Dropout(p=0.1, inplace=False)
(1): Linear(in_features=576, out_features=1, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
(gender_loss): CrossEntropyLoss()
(age_loss): L1Loss()
(gender_accuracy): BinaryAccuracy()
(train_gender_accuracy): BinaryAccuracy()
(age_mae): MeanAbsoluteError()
)
Sample images found: 4740
Initial age distribution:
Bin 0-9: 608
Bin 10-19: 263
Bin 20-29: 1525
Bin 30-39: 949
Bin 40-49: 408
Bin 50-59: 451
Bin 60-69: 249
Bin 70-79: 155
Bin 80-89: 132
Gender distribution: {0: 2387, 1: 2353}
Age distribution: [(26, 445), (1, 237), (28, 202), (24, 184), (30, 173), (35, 171), (25, 151), (29, 135), (32, 125), (27, 122), (2, 94), (36, 89), (40, 89), (23, 86), (45, 82), (22, 81), (50, 80), (31, 78), (54, 77), (38, 69), (21, 68), (34, 68), (37, 68), (39, 67), (3, 62), (8, 57), (55, 55), (58, 55), (52, 53), (20, 51), (4, 51), (65, 50), (60, 49), (42, 47), (18, 43), (33, 41), (46, 41), (15, 38), (56, 38), (16, 36), (53, 36), (61, 35), (9, 34), (5, 33), (80, 30), (10, 29), (17, 29), (49, 29), (70, 29), (75, 28), (41, 27), (48, 27), (47, 26), (51, 26), (85, 25), (12, 23), (62, 23), (63, 23), (43, 21), (72, 21), (14, 20), (19, 20), (6, 20), (7, 20), (44, 19), (73, 18), (76, 18), (90, 18), (66, 17), (13, 16), (57, 16), (59, 15), (67, 14), (78, 14), (68, 13), (69, 13), (64, 12), (11, 9), (79, 9), (82, 7), (84, 7), (89, 7), (71, 6), (74, 6), (77, 6), (86, 6), (88, 6), (81, 4), (99, 4), (87, 3), (95, 3), (96, 3), (116, 2), (92, 2), (100, 1), (105, 1), (111, 1), (83, 1), (93, 1)]
Total samples: 4740
Valid images: 4740
Gender distribution: {0: 2387, 1: 2353}
Age range: 1 - 116
Processing batches: 100%|██████████| 149/149 [00:04<00:00, 31.38batch/s]
Total predictions: 4740
Total image quality data: 23086
Matched data points: 4615
Performance¶
| | Female | Male | Overall |
|---|---|---|---|
| Support | 2353.000 | 2387.000 | 4740.000 |
| Accuracy | 0.931 | 0.931 | 0.931 |
| Precision | 0.924 | 0.938 | 0.931 |
| Recall | 0.938 | 0.924 | 0.931 |
| F1-score | 0.931 | 0.931 | 0.931 |
| AUC-ROC | NaN | NaN | 0.981 |
| PR-AUC | NaN | NaN | 0.978 |
| Log Loss | NaN | NaN | 0.179 |
| Brier Score | NaN | NaN | NaN |
| Metric | Value |
|---|---|
| MAE | 5.105901 |
| MSE | 54.144762 |
| RMSE | 7.358312 |
| R-squared | 0.862191 |
| MAPE | 25.161557 |
We've been able to achieve an accuracy of ~93% for gender predictions and Age MAE (Mean Absolute Error) of around 5.1 years.
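The regression metrics in the table above can be recomputed from the raw predictions; a sketch, with `y_true`/`y_pred` standing in for the notebook's prediction arrays (the MAPE definition used here is an assumption):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def age_metrics(y_true, y_pred):
    # Standard regression metrics for the age head, matching the table columns.
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    mse = mean_squared_error(y_true, y_pred)
    return {
        "MAE": mean_absolute_error(y_true, y_pred),
        "MSE": mse,
        "RMSE": float(np.sqrt(mse)),
        "R-squared": r2_score(y_true, y_pred),
        "MAPE": float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100),
    }
```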
| Model | Age Estimation (MAE) | Gender Classification (Accuracy) |
|---|---|---|
| XGBoost (+feat. extraction) | 5.89 | 93.80 |
| SVC(..) | 5.49 | 94.64 |
| VGG_f | 4.86 | 93.42 |
| ResNet50_f | 4.65 | 94.64 |
| SENet50_f | 4.58 | 94.90 |
(*https://arxiv.org/pdf/2110.12633)

While our model still lags behind the larger CNN baselines above in both age MAE and gender accuracy, these seem like reasonably good results for such a small model used directly (i.e. no ensembles/meta-models).
(Specific dataset split and preprocessing)
| Metric | VGG16 | ResNet50 | MobileNetV3-Small |
|---|---|---|---|
| Parameter Count | ~138 million | ~25.6 million | ~2.5 million |
| Model Size (PyTorch, FP32) | ~528 MB | ~98 MB | ~10 MB |
| Inference Speed (relative) | 1x (baseline) | ~2.5x faster | ~10x faster |
| FLOPs | ~15.5 billion | ~4.1 billion | ~56 million |
| Approx. Memory Usage (inference) | 1x | ~0.6x | ~0.15x |
Overall this is not necessarily exceptional: the UTK Face dataset is relatively small and specific compared to general image-classification tasks (which can effectively level the playing field for smaller models), and several other studies/benchmarks show MobileNet variants performing competitively with larger models on simple tasks like this (while performing significantly worse at more complex tasks like emotion detection or face recognition).
E.g. according to Savchenko, A. V. (2024, arXiv, https://ar5iv.labs.arxiv.org/html/2103.17107), MobileNet without any fine-tuning on UTKFace (i.e. the full UTKFace set was used for testing) actually outperformed VGG-16 and ResNet-50.
Age Classification¶
Accuracy of Gender Prediction by Age Group¶
| Age_Group | Total | Correct | Accuracy |
|---|---|---|---|
| 0-4 | 444 | 307 | 0.6914 |
| 4-14 | 261 | 215 | 0.8238 |
| 14-24 | 636 | 604 | 0.9497 |
| 24-30 | 1228 | 1187 | 0.9666 |
| 30-40 | 865 | 837 | 0.9676 |
| 40-50 | 399 | 393 | 0.9850 |
| 50-60 | 420 | 409 | 0.9738 |
| 60-70 | 229 | 218 | 0.9520 |
| 70-80 | 156 | 149 | 0.9551 |
| 80+ | 102 | 94 | 0.9216 |
We can see that gender prediction accuracy is reasonably high across all ranges except young children. Realistically, it's unlikely we can do much about that: facial features of babies are very different from adults'. It might be worth investigating a separate model for them, but it's unlikely it would achieve very high performance either.
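The per-age-group table above can be produced with a `pd.cut`/`groupby` pattern along these lines; the column names and the tiny example frame are assumptions, not the notebook's actual data:

```python
import pandas as pd

# Toy stand-in for the notebook's per-sample results frame.
df = pd.DataFrame({
    "True_Age": [2, 8, 25, 45, 70],
    "Gender_Correct": [0, 1, 1, 1, 1],
})

bins = [0, 4, 14, 24, 30, 40, 50, 60, 70, 80, 200]
labels = ["0-4", "4-14", "14-24", "24-30", "30-40",
          "40-50", "50-60", "60-70", "70-80", "80+"]
df["Age_Group"] = pd.cut(df["True_Age"], bins=bins, labels=labels)

# Count, correct count, and accuracy per age bin.
stats = df.groupby("Age_Group", observed=True)["Gender_Correct"].agg(
    Total="count", Correct="sum")
stats["Accuracy"] = stats["Correct"] / stats["Total"]
```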
Summary of Age Prediction¶
| Statistic | True Age | Predicted Age |
|---|---|---|
| Mean | 33.308439 | 32.147823 |
| Median | 29.000000 | 28.514690 |
| Min | 1.000000 | -2.139822 |
| Max | 116.000000 | 95.214233 |
Age Prediction by Age Group¶
| Age_Group | Support | Age_MAE | Age_MSE | Age_RMSE | Age_R-squared | Age_MAPE |
|---|---|---|---|---|---|---|
| 0-4 | 444 | 1.588580 | 11.325658 | 3.365361 | -9.241579 | 99.745904 |
| 4-14 | 261 | 4.011655 | 34.033093 | 5.833789 | -3.743251 | 46.700869 |
| 14-24 | 636 | 4.171022 | 32.965802 | 5.741585 | -2.937213 | 21.156784 |
| 24-30 | 1228 | 3.720786 | 30.006521 | 5.477821 | -10.167695 | 13.674633 |
| 30-40 | 865 | 6.270144 | 63.924114 | 7.995256 | -7.162335 | 17.644973 |
| 40-50 | 399 | 7.749943 | 96.742555 | 9.835779 | -10.194667 | 16.942367 |
| 50-60 | 420 | 7.311122 | 91.486462 | 9.564856 | -11.248783 | 13.271226 |
| 60-70 | 229 | 6.725516 | 80.393407 | 8.966237 | -8.236708 | 10.369088 |
| 70-80 | 156 | 7.617475 | 105.892985 | 10.290432 | -11.530508 | 10.082188 |
| 80+ | 102 | 8.947648 | 173.258202 | 13.162758 | -3.118748 | 9.777900 |
This table shows one of the flaws of using MAE as our target metric: it downplays inaccurate predictions for children and potentially exaggerates them as the subject's age increases.
I.e. misclassifying a newborn as a 5-year-old child (or the other way around) is a much bigger error than making the same 5-year mistake when the subject is over 70.
MAPE (Mean Absolute Percentage Error) would potentially be a better metric; however, it can be (and clearly is) problematic for very young ages (near zero), as it leads to extremely large or undefined percentages.
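A quick illustration of why MAPE misbehaves near zero: the same 4-year absolute error is a 400% error for a 1-year-old but under 6% for a 70-year-old:

```python
import numpy as np

def mape(y_true, y_pred):
    # Mean absolute percentage error; undefined/explosive when y_true ~ 0.
    return float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100)

infant_error = mape(np.array([1.0]), np.array([5.0]))   # 4-year miss on a 1-year-old
adult_error = mape(np.array([70.0]), np.array([74.0]))  # same 4-year miss on a 70-year-old
```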
LIME¶
Solving Age Balancing¶
['dataset/test_2_folds_last/111_1_0_20170120134646399.jpg.chip.jpg', 'dataset/test_2_folds_last/1_1_0_20170109194452834.jpg.chip.jpg', 'dataset/test_2_folds_last/9_0_0_20170110225030430.jpg.chip.jpg', 'dataset/test_2_folds_last/8_0_1_20170114025855492.jpg.chip.jpg', 'dataset/test_2_folds_last/41_1_1_20170117021604893.jpg.chip.jpg']
Most Misclassified Images (both gender/age)¶
Misclassified Gender¶
Looking at gender specifically, it's actually likely that our model performs better than the summarized results imply.
The images above showcase where our model was least accurate, and we can see that all except one are likely cases of mislabeled data in the original dataset (or the labels are accurate with respect to those individuals' self-identity).
We can see three main issues:
Some images are poor quality or strongly cropped. We may be able to solve this with preprocessing heuristics that exclude such samples from the training and test sets.
We can see certain patterns related to race and age. The model has issues classifying faces of non-white people, possibly due to different facial features or skin color (although the grayscale transform should partially mitigate that). It also struggles with very old people and with children/babies, possibly because of too-small sample sizes and relatively more "androgynous" facial features in those groups. We'll attempt to fix this using augmentation in combination with oversampling (i.e. we'll use transforms to create additional samples for underrepresented age bins; additionally, we'll use some of the color analysis from the EDA to also oversample images of underrepresented skin colors).
Many samples are potentially mislabeled. Some samples may be of people who self-identify as male/female while retaining facial features, hairstyles, etc. of the opposite gender, or they are simply mislabeled. Either way, this part would be the hardest to solve.
Filtering Out "Invalid" Samples¶
We'll use a mix of metrics to determine which images are very poor quality, lack enough detail for proper classification, etc.:
BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator):
A no-reference image quality assessment method. Uses scene statistics of locally normalized luminance coefficients to quantify possible losses of "naturalness" in the image due to distortions. Operates in the spatial domain.
Laplacian Variance:
A measure of image sharpness/blurriness. Uses the Laplacian operator to compute the second derivative of the image. Measures the variance of the Laplacian-filtered image.
FFT-based Blur Detection:
Uses Fast Fourier Transform to analyze the frequency components of an image. Applies a high-pass filter in the frequency domain and measures the remaining energy.
See the Data Analysis notebook for more details.
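As a rough illustration of the two sharpness heuristics, here are numpy-only sketches (the notebook presumably uses OpenCV, where `cv2.Laplacian(img, cv2.CV_64F).var()` is the usual one-liner for the first; these implementations and the cutoff value are assumptions):

```python
import numpy as np

LAPLACIAN = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=float)

def laplacian_variance(gray: np.ndarray) -> float:
    # Variance of the Laplacian response: low values indicate a blurry image.
    h, w = gray.shape
    out = np.zeros((h - 2, w - 2))
    for dy in range(3):
        for dx in range(3):
            out += LAPLACIAN[dy, dx] * gray[dy:dy + h - 2, dx:dx + w - 2]
    return float(out.var())

def fft_high_freq_energy(gray: np.ndarray, cutoff: int = 8) -> float:
    # Zero out a low-frequency square around the DC component, then measure
    # the remaining (high-frequency) energy; blurry images score low.
    f = np.fft.fftshift(np.fft.fft2(gray))
    cy, cx = np.array(f.shape) // 2
    f[cy - cutoff:cy + cutoff, cx - cutoff:cx + cutoff] = 0
    return float(np.mean(np.abs(f)))
```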
BRISQUE + Laplacian Variance¶
One obvious major shortcoming of this approach is that we're excluding a significant proportion of samples simply because our model performs very poorly on them.
While {TODO}
A production pipeline might be:
- Check if image is valid using heuristics (e.g. telling the user to position the camera better etc.)
Augmentation Based Oversampling¶
We'll use augmentation/transforms combined with oversampling to increase the number of samples in underrepresented classes. This approach:
- allows us to preserve original data characteristics while introducing variability
Potential issues:
- Risk of overfitting to augmented versions of underrepresented samples
- Possibility of introducing unintended biases if augmentation isn't carefully balanced
- May not fully address underlying dataset biases
- Requires careful monitoring to ensure improved performance across all age groups
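The oversampling half of the approach can be sketched with PyTorch's `WeightedRandomSampler`: draw underrepresented age bins more often and rely on random transforms to vary the repeated images. `ages` is a placeholder for the training labels, and inverse-frequency weighting is an assumption about the exact scheme:

```python
import torch
from torch.utils.data import WeightedRandomSampler

ages = torch.tensor([2, 2, 25, 25, 25, 25, 25, 70])  # placeholder labels
edges = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80])

bins = torch.bucketize(ages, edges)                   # decade bin per sample
counts = torch.bincount(bins, minlength=len(edges) + 1).float()
weights = 1.0 / counts[bins]                          # inverse-frequency weight per sample
sampler = WeightedRandomSampler(weights.tolist(), num_samples=len(ages), replacement=True)
# then: DataLoader(train_dataset, batch_size=256, sampler=sampler)
```

Because sampling is with replacement, rare bins are revisited repeatedly; the random augmentations are what keep those repeats from being exact duplicates.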
Comparing Both Models¶
Let's look at samples that were misclassified by the initial model but are now correct in the new model:
Of course, we have specifically selected the best-case examples (i.e. where the model's performance improved the most), which probably gives a much too optimistic picture of the overall improvement (relative to the overall increase in accuracy/MAE, which is not as significant).
Instead, we've selected some of the samples our initial model failed on that were unlikely to be mislabeled: